Yes, we can. In questions 3 and 4 we demonstrated that we can use fewer predictors to reduce the complexity of the model: our model came within 2% of the full model's performance while using only a quarter of the predictors. To run in a web browser we need to make the model more compact; we can do so either by creating fewer trees or by pruning the trees more aggressively, i.e. choosing a smaller value for the max_node parameter. Now let's draw some plots to see how much each of these affects performance. We will measure performance using the F-score and 10-fold cross-validation on balanced data.
## Project layout: all data files and rendered output live under base_folder.
base_folder <- "/home/amirsalar/Diabetes-R/" # change this:)
data_folder <- paste0(base_folder, "data/")
output_folder <- paste0(base_folder, "output/")

## Helper: full path of a CSV file inside the data folder.
data_csv <- function(stem) paste0(data_folder, stem, ".csv")

diabetes_012_path <- data_csv("diabetes_012_health_indicators_BRFSS2015")
diabetes_binary_5050_path <- data_csv("diabetes_binary_5050split_health_indicators_BRFSS2015")
diabetes_binary_path <- data_csv("diabetes_binary_health_indicators_BRFSS2015")
library(data.table)
library(ggplot2)
library(readxl)
library(corrplot)
## corrplot 0.92 loaded
library(forcats)
library(gbm)
## Loaded gbm 2.1.8.1
library(caret)
## Loading required package: lattice
library(doParallel)
## Loading required package: foreach
## Loading required package: iterators
## Loading required package: parallel
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(verification)
## Loading required package: fields
## Loading required package: spam
## Spam version 2.9-1 (2022-08-07) is loaded.
## Type 'help( Spam)' or 'demo( spam)' for a short introduction
## and overview of this package.
## Help for individual functions is also obtained by adding the
## suffix '.spam' to the function name, e.g. 'help( chol.spam)'.
##
## Attaching package: 'spam'
## The following objects are masked from 'package:base':
##
## backsolve, forwardsolve
## Loading required package: viridis
## Loading required package: viridisLite
##
## Try help(fields) to get started.
## Loading required package: boot
##
## Attaching package: 'boot'
## The following object is masked from 'package:lattice':
##
## melanoma
## Loading required package: CircStats
## Loading required package: MASS
## Loading required package: dtw
## Loading required package: proxy
##
## Attaching package: 'proxy'
## The following object is masked from 'package:spam':
##
## as.matrix
## The following objects are masked from 'package:stats':
##
## as.dist, dist
## The following object is masked from 'package:base':
##
## as.matrix
## Loaded dtw v1.23-1. See ?dtw for help, citation("dtw") for use in publication.
## Registered S3 method overwritten by 'verification':
## method from
## lines.roc pROC
##
## Attaching package: 'verification'
## The following object is masked from 'package:pROC':
##
## lines.roc
## Load the three BRFSS-2015 datasets and convert the binary outcome to a factor.
dt_012 <- data.table(read.csv(diabetes_012_path))
dt_5050 <- data.table(read.csv(diabetes_binary_5050_path))
dt <- data.table(read.csv(diabetes_binary_path))
## data.table's `:=` modifies the table in place, so no reassignment is needed.
dt_5050[, Diabetes_binary := as.factor(Diabetes_binary)]
dt[, Diabetes_binary := as.factor(Diabetes_binary)]
## Custom caret model specification wrapping randomForest so that `ntree` and
## the maximum number of terminal nodes can be tuned alongside `mtry`
## (caret's built-in "rf" method only tunes `mtry`).
customRF <- list(label = "Random Forest",
  library = "randomForest",
  loop = NULL,
  type = c("Classification", "Regression"),
  # Tunable parameters exposed to train().  The grid column stays named
  # "max_nodes" (callers' tuneGrid uses that name), but note that the
  # corresponding randomForest formal argument is spelled `maxnodes`.
  parameters = data.frame(parameter = c("mtry", "ntree", "max_nodes"),
                          class = c("numeric", "numeric", "numeric"),
                          label = c("mtry", "ntree", "max_nodes")),
  grid = function(x, y, len = NULL, search = "grid") {},
  # FIX: the original passed `max_nodes = param$max_nodes`, which is NOT a
  # randomForest argument (the formal is `maxnodes`); it does not partially
  # match and was silently swallowed by `...`, so the max_nodes grid had no
  # effect on the fitted trees -- consistent with the near-identical F scores
  # across max_nodes values in the printed results.
  fit = function(x, y, wts, param, lev, last, classProbs, ...)
    randomForest::randomForest(x, y, mtry = param$mtry, ntree = param$ntree,
                               maxnodes = param$max_nodes, ...),
  # Class predictions (or fitted values when newdata is NULL).
  predict = function(modelFit, newdata, submodels = NULL)
    if(!is.null(newdata)) predict(modelFit, newdata) else predict(modelFit),
  # Class-probability predictions, required because classProbs = TRUE.
  prob = function(modelFit, newdata, submodels = NULL)
    if(!is.null(newdata)) predict(modelFit, newdata, type = "prob") else predict(modelFit, type = "prob"),
  predictors = function(x, ...) {
    ## randomForest only splits on plain main effects (not interactions or
    ## terms like I(x^2)), so the variables actually used can be recovered
    ## from the forest's split-variable indices.
    varIndex <- as.numeric(names(table(x$forest$bestvar)))
    varIndex <- varIndex[varIndex > 0]
    varsUsed <- names(x$forest$ncat)[varIndex]
    varsUsed
  },
  # Variable importance, normalized to the shape caret expects
  # (an "Overall" column, or one column per class for classification).
  varImp = function(object, ...){
    varImp <- randomForest::importance(object, ...)
    if(object$type == "regression") {
      if("%IncMSE" %in% colnames(varImp)) {
        varImp <- data.frame(Overall = varImp[,"%IncMSE"])
      } else {
        varImp <- data.frame(Overall = varImp[,1])
      }
    }
    else {
      retainNames <- levels(object$y)
      if(all(retainNames %in% colnames(varImp))) {
        varImp <- varImp[, retainNames]
      } else {
        varImp <- data.frame(Overall = varImp[,1])
      }
    }
    out <- as.data.frame(varImp, stringsAsFactors = TRUE)
    # For two-class importance, collapse the two columns to their mean.
    if(dim(out)[2] == 2) {
      tmp <- apply(out, 1, mean)
      out[,1] <- out[,2] <- tmp
    }
    out
  },
  levels = function(x) x$classes,
  tags = c("Random Forest", "Ensemble Model", "Bagging", "Implicit Feature Selection"),
  sort = function(x) x[order(x[,1]),],
  # Out-of-bag performance estimates in caret's (metric, Kappa/Rsquared) form.
  oob = function(x) {
    out <- switch(x$type,
                  regression = c(sqrt(max(x$mse[length(x$mse)], 0)), x$rsq[length(x$rsq)]),
                  classification = c(1 - x$err.rate[x$ntree, "OOB"],
                                     e1071::classAgreement(x$confusion[,-dim(x$confusion)[2]])[["kappa"]]))
    names(out) <- if(x$type == "regression") c("RMSE", "Rsquared") else c("Accuracy", "Kappa")
    out
  })
## Spin up 4 worker processes so caret runs the CV folds in parallel.
cl <- makePSOCKcluster(4)
registerDoParallel(cl)

## Predictor subset selected in questions 3/4.
optVars <- c("GenHlth", "BMI", "HighBP", "Age", "HighChol")

fitControl <- trainControl(## 10-fold CV
  method = "cv",
  number = 10,
  summaryFunction = prSummary, # reports AUC / Precision / Recall / F
  classProbs = TRUE            # required by prSummary (stray trailing comma removed)
)

dt_5050_optVars <- dt_5050[, c("Diabetes_binary", optVars), with = FALSE]
dt_optVars <- dt[, c("Diabetes_binary", optVars), with = FALSE]

## FIX: seed the RNG so the 40k-row subsample -- and therefore the CV
## results below -- are reproducible across runs.
set.seed(42)
random_indices <- sample(seq_len(nrow(dt_5050_optVars)), size = 40000, replace = FALSE)
smallerDT <- dt_5050_optVars[random_indices]

## classProbs = TRUE needs syntactically valid factor levels (0/1 -> X0/X1).
levels(smallerDT$Diabetes_binary) <- make.names(levels(smallerDT$Diabetes_binary), unique = TRUE)

random_forest.tune_grid <- expand.grid(mtry = 2,
                                       ntree = c(30, 50, 100, 150, (1:8) * 200),
                                       max_nodes = c(4, 5, (2:8) * 3))
## Tune the custom random forest over the (ntree, max_nodes) grid,
## selecting the model with the highest cross-validated F score.
model.random_forest <- train(
  Diabetes_binary ~ .,
  data = smallerDT,
  method = customRF,
  metric = "F",
  tuneGrid = random_forest.tune_grid,
  trControl = fitControl
)

## Release the parallel workers before inspecting the results.
stopCluster(cl)
print(model.random_forest)
## Random Forest
##
## 40000 samples
## 5 predictor
## 2 classes: 'X0', 'X1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 36001, 36000, 35999, 36000, 36000, 35999, ...
## Resampling results across tuning parameters:
##
## ntree max_nodes AUC Precision Recall F
## 30 4 0.4296771 0.7598996 0.7097045 0.7338836
## 30 5 0.4323901 0.7608304 0.7080614 0.7334216
## 30 6 0.4272761 0.7623567 0.7072144 0.7336831
## 30 9 0.4248502 0.7624518 0.7064677 0.7333363
## 30 12 0.4271042 0.7620215 0.7074627 0.7336396
## 30 15 0.4341234 0.7630290 0.7078125 0.7343078
## 30 18 0.4223639 0.7618217 0.7068168 0.7332330
## 30 21 0.4250657 0.7602002 0.7081111 0.7331402
## 30 24 0.4251558 0.7611110 0.7074143 0.7331495
## 50 4 0.4485141 0.7636312 0.7057702 0.7334946
## 50 5 0.4499457 0.7627651 0.7061196 0.7332699
## 50 6 0.4475790 0.7628898 0.7074639 0.7340714
## 50 9 0.4473470 0.7630900 0.7032315 0.7318448
## 50 12 0.4498547 0.7638793 0.7039284 0.7325963
## 50 15 0.4541569 0.7615952 0.7066674 0.7330088
## 50 18 0.4505764 0.7618360 0.7072149 0.7334527
## 50 21 0.4533244 0.7627322 0.7078128 0.7341552
## 50 24 0.4542761 0.7624387 0.7089076 0.7346291
## 100 4 0.4895316 0.7624807 0.7077631 0.7340275
## 100 5 0.4790910 0.7623458 0.7086086 0.7344287
## 100 6 0.4794613 0.7629502 0.7064184 0.7335682
## 100 9 0.4841884 0.7636326 0.7041780 0.7326389
## 100 12 0.4820931 0.7629577 0.7054724 0.7330077
## 100 15 0.4787478 0.7627845 0.7063190 0.7334149
## 100 18 0.4810731 0.7625970 0.7046753 0.7324315
## 100 21 0.4790265 0.7642218 0.7050742 0.7333899
## 100 24 0.4759172 0.7626834 0.7079617 0.7342663
## 150 4 0.4973032 0.7634611 0.7062199 0.7336435
## 150 5 0.4956566 0.7630528 0.7057708 0.7332253
## 150 6 0.4960494 0.7645167 0.7043270 0.7330725
## 150 9 0.4935245 0.7616367 0.7081109 0.7338354
## 150 12 0.4907099 0.7624111 0.7061198 0.7331284
## 150 15 0.4972525 0.7633125 0.7045262 0.7326721
## 150 18 0.4920661 0.7644963 0.7059701 0.7340043
## 150 21 0.4993081 0.7648330 0.7061694 0.7342825
## 150 24 0.4963849 0.7636271 0.7055718 0.7334108
## 200 4 0.5061760 0.7626405 0.7075141 0.7339795
## 200 5 0.5040876 0.7627724 0.7054726 0.7329162
## 200 6 0.5027084 0.7639991 0.7059702 0.7337781
## 200 9 0.5039133 0.7633486 0.7059203 0.7334519
## 200 12 0.5034640 0.7641821 0.7056711 0.7337313
## 200 15 0.5053846 0.7630068 0.7073141 0.7340206
## 200 18 0.5047230 0.7645944 0.7045267 0.7332623
## 200 21 0.5040800 0.7626658 0.7065680 0.7334954
## 200 24 0.5037613 0.7642032 0.7045262 0.7330803
## 400 4 0.5313524 0.7626213 0.7063190 0.7333224
## 400 5 0.5274412 0.7638656 0.7061690 0.7338349
## 400 6 0.5258995 0.7624996 0.7054225 0.7328009
## 400 9 0.5227255 0.7634204 0.7062194 0.7336385
## 400 12 0.5311845 0.7625098 0.7072149 0.7337516
## 400 15 0.5316204 0.7624146 0.7064677 0.7333093
## 400 18 0.5284766 0.7634029 0.7054720 0.7332431
## 400 21 0.5256814 0.7642990 0.7054224 0.7336210
## 400 24 0.5246840 0.7631405 0.7056713 0.7332019
## 600 4 0.5367111 0.7627810 0.7057710 0.7331155
## 600 5 0.5459058 0.7629300 0.7053227 0.7329263
## 600 6 0.5398669 0.7632733 0.7059204 0.7334046
## 600 9 0.5385706 0.7629150 0.7060202 0.7332983
## 600 12 0.5402042 0.7621261 0.7050739 0.7324217
## 600 15 0.5397073 0.7631222 0.7052733 0.7330080
## 600 18 0.5375774 0.7641678 0.7067667 0.7342854
## 600 21 0.5429564 0.7640193 0.7058208 0.7337179
## 600 24 0.5374523 0.7641004 0.7055219 0.7335991
## 800 4 0.5460358 0.7638474 0.7052233 0.7333009
## 800 5 0.5467790 0.7638196 0.7059203 0.7336745
## 800 6 0.5502117 0.7634155 0.7045262 0.7327284
## 800 9 0.5458296 0.7632977 0.7058210 0.7333685
## 800 12 0.5478169 0.7640447 0.7046756 0.7330856
## 800 15 0.5494910 0.7635015 0.7057711 0.7334528
## 800 18 0.5440922 0.7642481 0.7050242 0.7333806
## 800 21 0.5467461 0.7645100 0.7056216 0.7338316
## 800 24 0.5483556 0.7630889 0.7070160 0.7339094
## 1000 4 0.5524762 0.7627355 0.7063188 0.7333849
## 1000 5 0.5542491 0.7631358 0.7058708 0.7333353
## 1000 6 0.5504953 0.7635728 0.7058211 0.7334959
## 1000 9 0.5555450 0.7637452 0.7054227 0.7333576
## 1000 12 0.5522138 0.7631345 0.7062690 0.7335419
## 1000 15 0.5587181 0.7637338 0.7057213 0.7335173
## 1000 18 0.5559474 0.7625695 0.7059702 0.7331199
## 1000 21 0.5549789 0.7636128 0.7061194 0.7336914
## 1000 24 0.5520659 0.7635014 0.7054225 0.7332458
## 1200 4 0.5597211 0.7633348 0.7058209 0.7333914
## 1200 5 0.5547699 0.7635274 0.7059205 0.7335407
## 1200 6 0.5569060 0.7642355 0.7052732 0.7335054
## 1200 9 0.5592021 0.7631779 0.7063187 0.7335881
## 1200 12 0.5592068 0.7632548 0.7064682 0.7336975
## 1200 15 0.5585329 0.7633049 0.7056218 0.7332494
## 1200 18 0.5573477 0.7639010 0.7053230 0.7333971
## 1200 21 0.5563779 0.7635994 0.7060200 0.7336234
## 1200 24 0.5632530 0.7636553 0.7059203 0.7335932
## 1400 4 0.5647969 0.7626807 0.7060199 0.7332052
## 1400 5 0.5604588 0.7640645 0.7062192 0.7339366
## 1400 6 0.5634458 0.7632868 0.7051238 0.7329899
## 1400 9 0.5598219 0.7632151 0.7061695 0.7335232
## 1400 12 0.5636248 0.7631781 0.7057711 0.7333068
## 1400 15 0.5629775 0.7630676 0.7063685 0.7335697
## 1400 18 0.5612655 0.7634600 0.7053229 0.7331729
## 1400 21 0.5643391 0.7631817 0.7059702 0.7333947
## 1400 24 0.5617441 0.7637273 0.7055721 0.7334420
## 1600 4 0.5654979 0.7632000 0.7057710 0.7333028
## 1600 5 0.5627115 0.7633344 0.7061695 0.7335850
## 1600 6 0.5678724 0.7637280 0.7053728 0.7333271
## 1600 9 0.5621229 0.7640354 0.7061194 0.7338799
## 1600 12 0.5660340 0.7637730 0.7057212 0.7335418
## 1600 15 0.5670141 0.7634396 0.7060201 0.7335460
## 1600 18 0.5620525 0.7635382 0.7059206 0.7335350
## 1600 21 0.5612028 0.7630899 0.7059205 0.7333377
## 1600 24 0.5657514 0.7635978 0.7062691 0.7337548
##
## Tuning parameter 'mtry' was held constant at a value of 2
## F was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 2, ntree = 50 and max_nodes
## = 24.
library("hexbin")
library("rayshader")
library(viridis)
## Render a 3-D surface (via rayshader) of each CV metric over the
## (max_nodes, ntree) tuning grid.  The ggplots are kept in `ggrandomforest`
## and the high-quality renders are written into `output_folder`.
## FIXES vs the original copy-pasted version:
##  * two of the four plot_gg calls misspelled `sunangle` as `ssunangle`,
##    so the intended sun angle was silently ignored for those plots;
##  * output paths were hard-coded relative ("output/...") instead of using
##    the `output_folder` defined at the top of the script.
plot_metric_surface <- function(metric, out_file) {
  results <- model.random_forest$results
  p <- ggplot(results) +
    geom_point(aes(x = max_nodes, y = ntree, color = .data[[metric]])) +
    scale_color_continuous(limits = range(results[[metric]])) +
    labs(color = metric)
  plot_gg(p, width = 3.5, multicore = TRUE, windowsize = c(1400, 850),
          sunangle = 180, zoom = 0.45, phi = 40, theta = 340, scale = 200,
          shadow_intensity = 0.3)
  render_highquality(filename = paste0(output_folder, out_file))
  p
}

ggrandomforest <- list()
ggrandomforest$fscore    <- plot_metric_surface("F",         "f-scorehigh_quality_plot.png")
ggrandomforest$auc       <- plot_metric_surface("AUC",       "auc-high_quality_plot.png")
ggrandomforest$precision <- plot_metric_surface("Precision", "precision-high_quality_plot.png")
ggrandomforest$recall    <- plot_metric_surface("Recall",    "recall-high_quality_plot.png")
## (stray residue from the rendered document -- legend labels of the plots
## above: AUC, Precision, F.  Not executable code.)